import numpy as np
import random 
import numba

try:
    profile
except NameError:
    # No ``profile`` name is in scope (presumably it is only injected when
    # running under a line profiler such as kernprof): fall back to a no-op
    # decorator so the ``@profile`` annotations below still work.
    def profile(x):
        """Identity decorator used when no profiler provides ``profile``."""
        return x

class Assignment(object):
    """A candidate parental-haplotype-of-origin assignment, scored on creation.

    ``assignments[p, j]`` (0 or 1) selects which of parent ``p``'s two
    haplotypes is used at locus ``j`` (``p = 0`` sire, ``p = 1`` dam, as
    indexed by getPhenotypes).

    Args:
        nLoci: number of loci.
        chrMap: per-locus chromosome labels, passed to getRecombinations.
        phenotypes: observed phenotype vector to fit.
        qtleffects: array of shape ``(2, 2, nLoci, nQtl)`` giving the QTL
            effect of each parent's first/second haplotype at each locus.
        assignments: optional initial ``(2, nLoci)`` 0/1 array; drawn with
            ``np.random.binomial(1, .5)`` when omitted.
        recRate: optional recombination rate. When ``None`` (default) the
            score is only the negative squared phenotype error; otherwise a
            recombination log-likelihood term is added to it.
    """

    def __init__(self, nLoci, chrMap, phenotypes, qtleffects, assignments = None, recRate = None):
        self.nLoci = nLoci
        self.chrMap = chrMap
        self.recRate = recRate

        if assignments is None:
            assignments = np.random.binomial(1, .5, size = (2, nLoci))

        self.assignments = assignments # Parental haplotype of origin for each (parent, locus).
        self.phenotypes = phenotypes
        self.qtleffects = qtleffects # 2 * 2 * nLoci * nQTL: QTL effect of each parental haplotype at each locus.
        self.score = self.evaluate()

    @profile
    def evaluate(self):
        """Return the score of this assignment (higher is better).

        The base score is ``-sum((phenotypes - predicted)**2)``. When
        ``self.recRate`` is not ``None`` a recombination log-likelihood term
        is added; otherwise getRecombinations is not called at all (its
        result used to be computed and then discarded).
        """
        phenotype_pred = getPhenotypes(self.assignments, self.qtleffects)
        score = -np.sum((self.phenotypes - phenotype_pred)**2)
        if self.recRate is not None:
            nRecomb = getRecombinations(self.assignments, self.chrMap)
            # NOTE(review): formula kept from the original code. nRecomb sums
            # both parents and chromosome boundaries are not excluded, so
            # (nLoci - nRecomb) only approximates the non-recombinant interval
            # count -- confirm before relying on this term.
            score = score + nRecomb*np.log(self.recRate) + (self.nLoci - nRecomb)*np.log(1-self.recRate)
        return score

    @profile
    def mutate(self):
        """Return a new Assignment with exactly one (parent, locus) entry flipped."""
        newAssignment = self.assignments.copy()
        hap = random.randrange(2)
        locus = random.randrange(self.nLoci)
        newAssignment[hap, locus] = 1 - newAssignment[hap, locus] # Flip the haplotype of origin at one locus.

        return Assignment(self.nLoci, self.chrMap, self.phenotypes, self.qtleffects,
                          newAssignment, recRate = self.recRate)

@numba.njit
def getPhenotypes(assignments, qtleffects):
    """Predict the phenotype vector implied by a haplotype-of-origin assignment.

    Adds ``qtleffects[parent, assignments[parent, locus], locus]`` over both
    parents (0 = sire, 1 = dam) and all ``qtleffects.shape[2]`` loci into a
    float32 vector of length ``qtleffects.shape[3]``.
    """
    n_loci = qtleffects.shape[2]
    total = np.zeros(qtleffects.shape[3], dtype = np.float32)
    for parent in range(2):
        origin = assignments[parent]
        for locus in range(n_loci):
            # origin[locus] picks which of this parent's haplotypes contributes.
            total += qtleffects[parent, origin[locus], locus]
    return total

@numba.njit
def getRecombinations(assignments, chrMap):
    """Count haplotype-of-origin switches between neighbouring loci.

    For each of the two rows of ``assignments``, a switch is counted whenever
    two consecutive loci carry the same ``chrMap`` label but different
    assignment values; changes across a chromosome boundary are not counted.
    Returns the total over both rows.
    """
    n_loci = assignments.shape[1]
    total = 0
    for parent in range(2):
        for locus in range(1, n_loci):
            same_chrom = chrMap[locus] == chrMap[locus - 1]
            if same_chrom and assignments[parent, locus] != assignments[parent, locus - 1]:
                total += 1
    return total



def getEffects(haplotype, qtls):
    """Scale each locus's QTL effects by the haplotype's allele at that locus.

    Row ``locus`` of the returned float32 array of shape
    ``(len(haplotype), qtls.shape[0])`` is ``haplotype[locus] * qtls[:, locus]``.
    """
    effects = np.zeros((haplotype.shape[0], qtls.shape[0]), dtype = np.float32)
    for locus, allele in enumerate(haplotype):
        effects[locus] = allele * qtls[:, locus]
    return effects

@profile
def imputeIndividualFromPhenotypes(ind, sire, dam, phenotypes, qtls, chrMap,
                                   nRestarts = 100, nIterations = 1000,
                                   acceptThreshold = 0.7, verbose = True):
    """Impute ``ind.dosages`` by searching parental haplotypes of origin.

    Runs ``nRestarts`` random-restart local searches, each of
    ``nIterations`` single-flip mutations (``Assignment.mutate``). A
    mutation replaces the current assignment if it scores higher, or
    otherwise when ``random.random() > acceptThreshold`` (so worse moves are
    accepted with probability ``1 - acceptThreshold``). The best assignment
    seen across all restarts determines the dosages.

    Args:
        ind: individual to impute; ``ind.dosages`` is overwritten and
            ``ind.idx`` is printed on each new best when ``verbose``.
        sire, dam: parents exposing ``haplotypes[0]`` / ``haplotypes[1]``
            as per-locus arrays.
        phenotypes: observed phenotype vector (compared against a length
            ``nQtl`` prediction).
        qtls: array of shape ``(nQtl, nLoci)``.
        chrMap: per-locus chromosome labels.
        nRestarts, nIterations, acceptThreshold: search settings; the
            defaults are the previously hard-coded values.
        verbose: print progress on every new best assignment.

    Returns:
        None; the result is stored in ``ind.dosages`` (float32, length nLoci).
    """
    nQtl, nLoci = qtls.shape
    qtleffects = _buildQtlEffects(sire, dam, qtls, nLoci, nQtl)

    # Baseline best: one random assignment, drawn before any restart.
    best_assignment = Assignment(nLoci, chrMap, phenotypes, qtleffects)

    for _ in range(nRestarts):
        current_assignment = Assignment(nLoci, chrMap, phenotypes, qtleffects)

        for _ in range(nIterations):
            new_assignment = current_assignment.mutate()

            if new_assignment.score > best_assignment.score:
                best_assignment = new_assignment
                if verbose:
                    print(ind.idx, new_assignment.score, best_assignment.score)

            if new_assignment.score > current_assignment.score:
                current_assignment = new_assignment
            elif random.random() > acceptThreshold:
                # Occasionally accept a non-improving move to escape local optima.
                current_assignment = new_assignment

    ind.dosages = _dosagesFromAssignment(best_assignment.assignments, sire, dam)


def _buildQtlEffects(sire, dam, qtls, nLoci, nQtl):
    """Stack getEffects for (sire, dam) x (haplotype 0, 1) into a (2, 2, nLoci, nQtl) float32 array."""
    qtleffects = np.full((2, 2, nLoci, nQtl), 0, dtype = np.float32)
    for p, parent in enumerate((sire, dam)):
        for h in range(2):
            qtleffects[p, h, :, :] = getEffects(parent.haplotypes[h], qtls)
    return qtleffects


def _dosagesFromAssignment(assign_mat, sire, dam):
    """Return float32 dosages: inherited sire allele plus inherited dam allele at each locus."""
    a_sire = assign_mat[0]
    a_dam = assign_mat[1]
    # assign value 0 selects haplotypes[0], value 1 selects haplotypes[1].
    dosages = ((1 - a_sire) * sire.haplotypes[0] + a_sire * sire.haplotypes[1]
               + (1 - a_dam) * dam.haplotypes[0] + a_dam * dam.haplotypes[1])
    return np.asarray(dosages).astype(np.float32)

# def expNorm(vect):
#     vect -= np.max(vect)
#     tmp = np.exp(vect)
#     return tmp/np.sum(tmp)

# def fitAssignments(patEffects, matEffects, phenotypes, regions):
#     nSamples = 10
#     nRegions = regions.shape[0]

#     # sampleWeights = np.full(nSamples, 0, dtype = np.float32)
#     # patSamples = np.full((nSamples, nRegions), 0, dtype = np.float32)
#     # matSamples = np.full((nSamples, nRegions), 0, dtype = np.float32)

#     matAssign = np.random.binomial(1, .5, size = nRegions)
#     patAssign = np.random.binomial(1, .5, size = nRegions)

#     for i in range(nSamples):
#         score, matAssign, patAssign = sample(matAssign, patAssign, patEffects, matEffects, phenotypes, regions)

#     weights = expNorm(sampleWeights)

#     patOutput = np.full(nRegions, 0, dtype = np.float32)
#     matOutput = np.full(nRegions, 0, dtype = np.float32)
    
#     print(weights)
#     print(patSamples)
#     print(matSamples)
    
#     for i in range(nSamples):
#         patOutput += weights[i] * patSamples[i,:]
#     for i in range(nSamples):
#         matOutput += weights[i] * matSamples[i,:]

#     return patOutput, matOutput

# def sample(matAssign, patAssign, patEffects, matEffects, phenotypes, regions) :
#     nRegions = regions.shape[0]
    
#     matChange = np.random.binomial(1, .1, size = nRegions)
#     patChange = np.random.binomial(1, .1, size = nRegions)
    
#     matAssign = matAssign
#     patAssign = patAssign 
    
#     score = scoreAssignment(matAssign, patAssign, patEffects, matEffects, phenotypes, regions)
#     return (score, matAssign, patAssign)

# def scoreAssignment(matAssign, patAssign, patEffects, matEffects, phenotypes, regions, recRate = .1):
#     nRegions = len(matAssign)
#     nTraits = patEffects[0].shape[1]
#     # Get predicted phenotype.

#     phenotype_pred = np.full(nTraits, 0, dtype = np.float32)
#     for i in range(nRegions):
#         phenotype_pred += matEffects[matAssign[i]][i,:]
#         phenotype_pred += patEffects[patAssign[i]][i,:]

#     distance = np.sum((phenotypes - phenotype_pred)**2)

#     nRec = 0
#     assignment = matAssign[0]
#     chrom = regions[0, 0]
#     for i in range(1, nRegions):
#         if regions[i,0] != chrom:
#             chrom = regions[i,0]
#             assignment = matAssign[i]
#         if assignment != matAssign[i]:
#             nRec += 1
#             assignment = matAssign[i]

#     assignment = patAssign[0]
#     chrom = regions[0, 0]
#     for i in range(1, nRegions):
#         if regions[i,0] != chrom:
#             chrom = regions[i,0]
#             assignment = patAssign[i]
#         if assignment != patAssign[i]:
#             nRec += 1
#             assignment = patAssign[i]

#     return -distance + nRec*np.log(recRate) + (nRegions - nRec)*np.log(1-recRate)



